import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import timm

from utils import FORMAT_INFO, to_device
from tokenizer import SOS_ID, EOS_ID, PAD_ID, MASK_ID
from searchStrategy.GreedySearch import GreedySearch
from model.Transformer import TransformerDecoder
from model.Embedding import Embeddings
from model.SwinTransformer import SwinTransformer



class Encoder(nn.Module):
    """Image encoder that wraps a ``SwinTransformer`` backbone.

    Only encoder names starting with ``'swin'`` are supported; any other
    name is rejected with ``NotImplementedError``.
    """

    def __init__(self, args, pretrained=False):
        """Build the backbone selected by ``args.encoder``.

        Args:
            args: Configuration object; only ``args.encoder`` is read here.
            pretrained: Accepted for interface compatibility; not used by
                this constructor.

        Raises:
            NotImplementedError: If ``args.encoder`` does not start with
                ``'swin'``.
        """
        super().__init__()
        model_name = args.encoder
        self.model_name = model_name

        if model_name.startswith('swin'):
            self.model_type = 'swin'
            self.transformer = SwinTransformer()
            self.n_features = self.transformer.num_features
            # Replace the backbone head with a pass-through.
            self.transformer.head = nn.Identity()
        else:
            # ``NotImplemented`` is a comparison sentinel, not an exception;
            # raising it produces a confusing TypeError.
            raise NotImplementedError(f"Unsupported encoder: {model_name!r}")

    def forward(self, x, refs=None):
        """Run the backbone on ``x``.

        Args:
            x: Input passed straight to the backbone.
            refs: Unused; kept for interface compatibility.

        Returns:
            The ``(features, hiddens)`` pair produced by the backbone.

        Raises:
            NotImplementedError: If the configured model type is not ``'swin'``.
        """
        if self.model_type != 'swin':
            raise NotImplementedError(f"Unsupported encoder type: {self.model_type!r}")

        features, hiddens = self.transformer(x)
        return features, hiddens


class TransformerDecoderBase(nn.Module):
    """Shared pieces for Transformer decoders.

    Holds the projection from encoder features to the decoder width, an
    optional learned positional embedding for the encoder sequence, and the
    ``TransformerDecoder`` stack itself.
    """

    def __init__(self, args):
        """Create the projection, optional positional table and decoder.

        Args:
            args: Configuration object. Reads ``encoder_dim``,
                ``dec_hidden_size``, ``enc_pos_emb``, ``dec_num_layers``,
                ``dec_attn_heads``, ``hidden_dropout``, ``attn_dropout`` and
                ``max_relative_positions``.
        """
        super().__init__()
        self.args = args

        hidden_size = args.dec_hidden_size

        # Linear map from encoder feature width to decoder width.
        self.enc_trans_layer = nn.Sequential(
            nn.Linear(args.encoder_dim, hidden_size)
        )

        # The positional table has 144 rows, so ``enc_transform`` cannot
        # handle flattened encoder sequences longer than 144 when enabled.
        if args.enc_pos_emb:
            self.enc_pos_emb = nn.Embedding(144, args.encoder_dim)
        else:
            self.enc_pos_emb = None

        decoder_config = dict(
            num_layers=args.dec_num_layers,
            d_model=hidden_size,
            heads=args.dec_attn_heads,
            d_ff=hidden_size * 4,
            copy_attn=False,
            self_attn_type="scaled-dot",
            dropout=args.hidden_dropout,
            attention_dropout=args.attn_dropout,
            max_relative_positions=args.max_relative_positions,
            aan_useffn=False,
            full_context_alignment=False,
            alignment_layer=0,
            alignment_heads=0,
            pos_ffn_activation_fn='gelu',
        )
        self.decoder = TransformerDecoder(**decoder_config)

    def enc_transform(self, encoder_out):
        """Flatten encoder output to a sequence and project it.

        Args:
            encoder_out: Tensor whose first dimension is the batch and whose
                last dimension is the feature width; all middle dimensions
                are merged into one sequence dimension.

        Returns:
            Tensor of shape ``(batch, seq, dec_hidden_size)``.
        """
        sequence = encoder_out.view(encoder_out.size(0), -1, encoder_out.size(-1))

        if self.enc_pos_emb is not None:
            positions = torch.arange(sequence.size(1), device=sequence.device)
            # Broadcast the per-position embedding over the batch.
            sequence = sequence + self.enc_pos_emb(positions).unsqueeze(0)

        return self.enc_trans_layer(sequence)


class TransformerDecoderAR(TransformerDecoderBase):
    """Autoregressive Transformer decoder over a tokenizer vocabulary.

    Adds token embeddings and an output projection on top of
    ``TransformerDecoderBase`` and provides teacher-forced ``forward`` as
    well as greedy step-by-step ``decode``.
    """

    def __init__(self, args, tokenizer):
        """Create embeddings and the vocabulary projection.

        Args:
            args: Configuration object (see ``TransformerDecoderBase``);
                additionally reads ``dec_hidden_size`` and ``hidden_dropout``.
            tokenizer: Object whose ``len()`` gives the vocabulary size.
        """
        super().__init__(args)
        self.tokenizer = tokenizer

        self.vocab_size = len(self.tokenizer)

        self.output_layer = nn.Linear(args.dec_hidden_size, self.vocab_size, bias=True)

        self.embeddings = Embeddings(
            word_vec_size=args.dec_hidden_size,
            word_vocab_size=self.vocab_size,
            word_padding_idx=PAD_ID,
            position_encoding=True,
            dropout=args.hidden_dropout)

    def dec_embedding(self, tgt, step=None):
        """Embed target tokens and build the padding mask.

        Args:
            tgt: Token-id tensor with a trailing singleton dimension
                (callers pass ``labels.unsqueeze(-1)`` or a ``(-1, 1, 1)`` view).
            step: Optional decoding step forwarded to the embedding layer.

        Returns:
            Tuple ``(emb, tgt_pad_mask)`` where ``emb`` is 3-D and
            ``tgt_pad_mask`` marks padding positions, with dims 1 and 2 of
            ``tgt`` swapped.
        """
        pad_idx = self.embeddings.word_padding_idx

        tgt_pad_mask = tgt.data.eq(pad_idx).transpose(1, 2)
        emb = self.embeddings(tgt, step=step)
        assert emb.dim() == 3

        return emb, tgt_pad_mask

    def forward(self, encoder_out, labels, label_lengths):
        """Teacher-forced pass.

        Args:
            encoder_out: Encoder features, flattened by ``enc_transform``.
            labels: Target token ids, indexed as ``(batch, length)``.
            label_lengths: Unused here; kept for interface compatibility.

        Returns:
            Tuple ``(logits[:, :-1], labels[:, 1:], dec_out)``: logits for
            every position but the last, the targets shifted by one, and the
            raw decoder hidden states.
        """
        memory_bank = self.enc_transform(encoder_out)

        tgt = labels.unsqueeze(-1)
        tgt_emb, tgt_pad_mask = self.dec_embedding(tgt)
        dec_out, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank, tgt_pad_mask=tgt_pad_mask)

        logits = self.output_layer(dec_out)

        return logits[:, :-1], labels[:, 1:], dec_out

    def decode(self, encoder_out, beam_size: int, n_best: int, min_length: int = 1, max_length: int = 256,
               labels=None):
        """Greedy decoding.

        Args:
            encoder_out: Encoder features; first dimension is the batch.
            beam_size: Must be ``1``; only greedy search is implemented.
            n_best: Unused; kept for interface compatibility.
            min_length: Minimum length passed to the search strategy.
            max_length: Maximum length passed to the search strategy.
            labels: Optional reference token ids; when given, column
                ``step + 1`` is handed to the strategy at every step.

        Returns:
            Tuple ``(predictions, scores, token_scores, hidden)`` taken from
            the search strategy.

        Raises:
            NotImplementedError: If ``beam_size`` is not ``1``.
        """
        # Previously any other beam size crashed later with an unbound
        # ``decode_strategy``; fail early with a clear error instead.
        if beam_size != 1:
            raise NotImplementedError(
                f"Only greedy decoding (beam_size=1) is supported, got beam_size={beam_size}")

        batch_size = encoder_out.size(0)
        memory_bank = self.enc_transform(encoder_out)

        decode_strategy = GreedySearch(
            sampling_temp=0.0, keep_topk=1, batch_size=batch_size, min_length=min_length, max_length=max_length,
            pad=PAD_ID, bos=SOS_ID, eos=EOS_ID,
            return_attention=False, return_hidden=True)

        _, memory_bank = decode_strategy.initialize(memory_bank=memory_bank)

        for step in range(decode_strategy.max_length):

            tgt = decode_strategy.current_predictions.view(-1, 1, 1)
            # NOTE(review): ``step`` is passed to the decoder but not to
            # ``dec_embedding``; confirm the embedding layer's positional
            # encoding does not need it.
            tgt_emb, tgt_pad_mask = self.dec_embedding(tgt)

            dec_out, dec_attn, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank,
                                                 tgt_pad_mask=tgt_pad_mask, step=step)
            attn = dec_attn.get("std", None)

            dec_logits = self.output_layer(dec_out)
            dec_logits = dec_logits.squeeze(1)
            log_probs = F.log_softmax(dec_logits, dim=-1)
            label = labels[:, step + 1] if labels is not None and step + 1 < labels.size(1) else None

            decode_strategy.advance(log_probs, attn, dec_out, label)
            any_finished = decode_strategy.is_finished.any()

            if any_finished:
                decode_strategy.update_finished()

                if decode_strategy.done:
                    break

            select_indices = decode_strategy.select_indices

            if any_finished:
                # Drop finished sequences from everything indexed by batch.
                memory_bank = memory_bank.index_select(0, select_indices)
                if labels is not None:
                    labels = labels.index_select(0, select_indices)

                self.map_state(lambda state, dim: state.index_select(dim, select_indices))

        return (decode_strategy.predictions, decode_strategy.scores,
                decode_strategy.token_scores, decode_strategy.hidden)

    def map_state(self, fn):
        """Apply ``fn(value, batch_dim)`` to every non-``None`` leaf of the decoder cache.

        Args:
            fn: Callable taking ``(value, batch_dim)`` and returning the
                replacement value; nested dicts are traversed recursively.
        """

        def _recursive_map(struct, batch_dim=0):

            for k, v in struct.items():

                if v is not None:

                    if isinstance(v, dict):
                        _recursive_map(v)

                    else:
                        struct[k] = fn(v, batch_dim)

        if self.decoder.state["cache"] is not None:
            _recursive_map(self.decoder.state["cache"])


class GraphPredictor(nn.Module):
    """Pairwise edge classifier (and optional coordinate regressor) over hidden states."""

    def __init__(self, decoder_dim, coords=False):
        """Create the pairwise MLP and, if requested, the coordinate MLP.

        Args:
            decoder_dim: Width of the incoming hidden states.
            coords: When true, also build a head that maps each selected
                hidden state to 2 values.
        """
        super().__init__()
        self.coords = coords

        # Scores 7 classes for every ordered pair of selected positions.
        self.mlp = nn.Sequential(
            nn.Linear(decoder_dim * 2, decoder_dim),
            nn.GELU(),
            nn.Linear(decoder_dim, 7),
        )

        if coords:
            self.coords_mlp = nn.Sequential(
                nn.Linear(decoder_dim, decoder_dim),
                nn.GELU(),
                nn.Linear(decoder_dim, 2),
            )

    def forward(self, hidden, indices=None):
        """Predict pairwise edge logits for selected positions.

        Args:
            hidden: Tensor of shape ``(batch, length, dim)``.
            indices: Optional ``(batch, k)`` tensor of positions to select
                per sample. When omitted, positions 3, 6, 9, ... are used.

        Returns:
            Dict with ``'edges'`` of shape ``(batch, 7, k, k)`` and, when the
            coordinate head is enabled, ``'coords'`` of shape ``(batch, k, 2)``.
        """
        batch, seq_len, dim = hidden.size()

        if indices is not None:
            # Pair every selected position with its batch row.
            rows = torch.arange(batch).unsqueeze(1).expand_as(indices).reshape(-1)
            cols = indices.view(-1)
            hidden = hidden[rows, cols].view(batch, -1, dim)
        else:
            hidden = hidden[:, list(range(3, seq_len, 3))]

        batch, count, dim = hidden.size()
        pair_shape = (batch, count, count, dim)
        first = hidden.unsqueeze(2).expand(pair_shape)
        second = hidden.unsqueeze(1).expand(pair_shape)
        pairs = torch.cat([first, second], dim=3)

        # Move the class dimension in front of the two position dimensions.
        results = {'edges': self.mlp(pairs).permute(0, 3, 1, 2)}

        if self.coords:
            results['coords'] = self.coords_mlp(hidden)

        return results


def get_edge_prediction(edge_prob):
    """Symmetrise pairwise class probabilities and pick the best class.

    ``edge_prob`` is an ``n x n x 7`` nested list that is modified in place.
    For each pair ``i < j`` channels 0-4 are averaged between ``[i][j]`` and
    ``[j][i]``. Channels 5 and 6 are averaged cross-wise, so channel 5 of one
    direction is paired with channel 6 of the other (presumably the two are
    direction-dependent classes; verify against the label definition).

    Args:
        edge_prob: Nested list of probabilities; may be empty.

    Returns:
        Tuple ``(prediction, score)`` of ``n x n`` nested lists holding the
        arg-max class and its probability; ``([], [])`` for empty input.
    """
    if not edge_prob:
        return [], []

    size = len(edge_prob)
    for a in range(size):
        for b in range(a + 1, size):
            forward, backward = edge_prob[a][b], edge_prob[b][a]

            for channel in range(5):
                forward[channel] = (forward[channel] + backward[channel]) / 2
                backward[channel] = forward[channel]

            # Order matters: both averages read the not-yet-overwritten
            # backward values before those are replaced below.
            forward[5] = (forward[5] + backward[6]) / 2
            forward[6] = (forward[6] + backward[5]) / 2
            backward[5] = forward[6]
            backward[6] = forward[5]

    prediction = np.argmax(edge_prob, axis=2).tolist()
    score = np.max(edge_prob, axis=2).tolist()
    return prediction, score



class Decoder(nn.Module):
    """Container that runs one sub-decoder per configured output format.

    The ``'edges'`` format uses a ``GraphPredictor``; every other format
    uses a ``TransformerDecoderAR`` with that format's tokenizer.
    """

    def __init__(self, args, tokenizer):
        """Build one sub-decoder per entry of ``args.formats``.

        Args:
            args: Configuration object. Reads ``formats``,
                ``dec_hidden_size``, ``continuous_coords`` and
                ``compute_confidence`` (plus whatever the sub-decoders read).
            tokenizer: Mapping from format name to its tokenizer.
        """
        super().__init__()
        self.args = args
        self.formats = args.formats
        self.tokenizer = tokenizer

        decoder = {}
        for format_ in args.formats:

            if format_ == 'edges':
                decoder['edges'] = GraphPredictor(args.dec_hidden_size, coords=args.continuous_coords)

            else:
                decoder[format_] = TransformerDecoderAR(args, tokenizer[format_])

        self.decoder = nn.ModuleDict(decoder)
        self.compute_confidence = args.compute_confidence

    def forward(self, encoder_out, hiddens, refs):
        """Teacher-forced pass over all formats.

        Args:
            encoder_out: Encoder features handed to each sequence decoder.
            hiddens: Unused; kept for interface compatibility.
            refs: Reference data keyed by format name, plus
                ``'atom_indices'``, ``'edges'`` and (optionally)
                ``'coords'`` for the ``'edges'`` format.

        Returns:
            Dict mapping each format to its sub-decoder output; for
            ``'edges'`` the value is ``(predictions, targets)``.

        Raises:
            NotImplementedError: If ``'edges'`` is processed before a
                ``'chartok_coords'`` result is available.
        """
        results = {}
        refs = to_device(refs, encoder_out.device)
        for format_ in self.formats:

            if format_ == 'edges':

                # The edge head consumes the hidden states of the
                # 'chartok_coords' decoder; without them there is nothing to
                # predict from (this used to fail with an unbound local).
                if 'chartok_coords' not in results:
                    raise NotImplementedError(
                        "'edges' requires 'chartok_coords' to appear earlier in formats")

                dec_out = results['chartok_coords'][2]
                predictions = self.decoder['edges'](dec_out, indices=refs['atom_indices'][0])

                targets = {'edges': refs['edges']}
                if 'coords' in predictions:
                    targets['coords'] = refs['coords']

                results['edges'] = (predictions, targets)

            else:
                labels, label_lengths = refs[format_]
                results[format_] = self.decoder[format_](encoder_out, labels, label_lengths)

        return results

    def decode(self, encoder_out, hiddens=None, refs=None, beam_size=1, n_best=1):
        """Run inference for all formats and assemble per-sample predictions.

        Args:
            encoder_out: Encoder features handed to each sequence decoder.
            hiddens: Unused; kept for interface compatibility.
            refs: Unused; kept for interface compatibility.
            beam_size: Forwarded to the sequence decoders.
            n_best: Forwarded to the sequence decoders.

        Returns:
            List with one dict per sample. Each dict holds the decoded
            sequence output under its format name and, when ``'edges'`` is
            configured, ``'edges'`` plus optional confidence entries.

        Raises:
            NotImplementedError: If ``'edges'`` is processed before a
                ``'chartok_coords'`` result is available.
        """
        results = {}
        predictions = []
        for format_ in self.formats:

            if format_ in ['atomtok', 'atomtok_coords', 'chartok_coords']:

                max_len = FORMAT_INFO[format_]['max_len']
                results[format_] = self.decoder[format_].decode(encoder_out, beam_size, n_best, max_length=max_len)
                outputs, scores, token_scores, *_ = results[format_]

                tokenizer = self.tokenizer[format_]
                beam_preds = [[tokenizer.sequence_to_smiles(x.tolist()) for x in pred]
                              for pred in outputs]

                # Keep only the first hypothesis of every sample.
                # NOTE(review): this replaces ``predictions`` for each
                # sequence format, so with several sequence formats only the
                # last one survives - confirm that is intended.
                predictions = [{format_: pred[0]} for pred in beam_preds]

                if self.compute_confidence:

                    for i, prediction in enumerate(predictions):
                        entry = prediction[format_]
                        # Shift the stored indices back by 3 positions
                        # (presumably from the last coordinate token to the
                        # symbol token - TODO confirm against the tokenizer).
                        indices = np.array(entry['indices']) - 3

                        if format_ == 'chartok_coords':

                            # Geometric mean of the token scores spanning
                            # each symbol's characters.
                            atom_scores = []
                            for symbol, index in zip(entry['symbols'], indices):

                                atom_score = (np.prod(token_scores[i][0][index - len(symbol) + 1:index + 1])
                                              ** (1 / len(symbol))).item()
                                atom_scores.append(atom_score)
                        else:
                            atom_scores = np.array(token_scores[i][0])[indices].tolist()

                        entry['atom_scores'] = atom_scores
                        entry['average_token_score'] = scores[i][0]

            if format_ == 'edges':
                if 'chartok_coords' in results:
                    atom_format = 'chartok_coords'

                else:
                    # ``NotImplemented`` is not an exception; raise the
                    # proper error type.
                    raise NotImplementedError(
                        "'edges' requires 'chartok_coords' to appear earlier in formats")

                dec_out = results[atom_format][3]
                for i in range(len(dec_out)):

                    # First hypothesis of sample i, with a batch dim of 1.
                    hidden = dec_out[i][0].unsqueeze(0)
                    indices = torch.LongTensor(predictions[i][atom_format]['indices']).unsqueeze(0)
                    pred = self.decoder['edges'](hidden, indices)

                    # Put the class dimension last and normalise over it.
                    prob = F.softmax(pred['edges'].squeeze(0).permute(1, 2, 0), dim=2).tolist()
                    edge_pred, edge_score = get_edge_prediction(prob)
                    predictions[i]['edges'] = edge_pred

                    if self.compute_confidence:
                        predictions[i]['edge_scores'] = edge_score
                        predictions[i]['edge_score_product'] = np.sqrt(np.prod(edge_score)).item()
                        predictions[i]['overall_score'] = predictions[i][atom_format]['average_token_score'] * \
                                                          predictions[i]['edge_score_product']

                        # Intermediate values are not part of the output.
                        predictions[i][atom_format].pop('average_token_score')
                        predictions[i].pop('edge_score_product')

        return predictions
